{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/facebookresearch/fairo/blob/master/tutorials/how_to_build_a_simple_agent.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oo4u3tYsssHj"
   },
   "source": [
    "# How to build your own agent\n",
    "\n",
    "<p align=\"center\">\n",
    "   <img src=\"https://craftassist.s3-us-west-2.amazonaws.com/pubr/demo.gif\" />\n",
    "</p>\n",
    "\n",
    "## Build a simple agent\n",
    "In this tutorial, we will build a simple agent that catches a randomly moving bot in a 5x5 grid world.  The goal is to understand the high level organization of the droidlet agent.\n",
    "\n",
    "\n",
    "## Control logic\n",
    "\n",
    "The basic droidlet agent is made up of four major components: a perceptual API, a memory system, a controller, and a task queue. In each iteration of the event loop, the agent will run perceptual modules, updating the memory system with what it perceives, maybe place tasks into the task queue, and lastly, pop all finished tasks and then step the highest priority task.\n",
    "\n",
    "A typical event loop is as follows: \n",
    "\n",
    "> **while** True **do**\n",
    ">> run [perceptual modules](https://facebookresearch.github.io/fairo/perception.html), update [memory](https://facebookresearch.github.io/fairo/memory.html)\n",
    ">>\n",
    ">> step [controller](https://facebookresearch.github.io/fairo/controller.html)\n",
    ">>\n",
    ">> step highest priority [task](https://facebookresearch.github.io/fairo/tasks.html)\n",
    "\n",
    "<!---\n",
    "\n",
    "### **Perception**\n",
    "Perception modules\n",
    "\n",
    " is where the agent perceives the world it resides. Most of the perceptual modules in our example agents are visual: e.g. object detection and instance segmentation. You can customize your own perception modules and have it registered in the agent.\n",
    "\n",
    "All the information perception modules receive should go into agent's memory system. \n",
    "\n",
    "\n",
    "### **Memory System**\n",
    "Memory system serves as the interface for passing information between the various components of the agent. It consists of an AgentMemory object which is the entry point to the underlying SQL database and some MemoryNodes which represents a particular entity or event. It stores and organizes information like: \n",
    "- player info\n",
    "- time info\n",
    "- program info\n",
    "- task info\n",
    "- etc.\n",
    "\n",
    "### **Controller**\n",
    "\n",
    "Controller is where agent interpret commands, carry out dialogues and place tasks on the task stack.\n",
    "\n",
    "### **Task queue**\n",
    "Task queue stores tasks, which are (mostly) self-contained lower-level world interactions (e.g. Move, Point). For each event loop, one task is poped out of task queue and got executed by the agent.\n",
    "-->"
   ]
  },
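  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The event loop above can be made concrete with a small, self-contained sketch. It is only an illustration under simplifying assumptions: `ToyAgent`, `run_loop`, the dict used as memory, and the list used as task queue are hypothetical stand-ins, not droidlet classes, and the loop runs a fixed number of iterations instead of forever.\n",
    "\n",
    "```python\n",
    "# Toy illustration only: ToyAgent and run_loop are hypothetical names,\n",
    "# not part of the droidlet API.\n",
    "class ToyAgent:\n",
    "    def __init__(self):\n",
    "        self.memory = {}   # stands in for the memory system\n",
    "        self.tasks = []    # stands in for the task queue\n",
    "        self.count = 0\n",
    "\n",
    "    def perceive(self):\n",
    "        # run perception and write what was perceived into memory\n",
    "        self.count += 1\n",
    "        self.memory['ticks_seen'] = self.count\n",
    "\n",
    "    def controller_step(self):\n",
    "        # place a new task only when nothing is pending\n",
    "        if not self.tasks:\n",
    "            self.tasks.append('task-%d' % self.count)\n",
    "\n",
    "    def task_step(self):\n",
    "        # step the pending task (here it finishes immediately)\n",
    "        return self.tasks.pop()\n",
    "\n",
    "\n",
    "def run_loop(agent, iterations):\n",
    "    done = []\n",
    "    for _ in range(iterations):   # a real agent would loop forever\n",
    "        agent.perceive()\n",
    "        agent.controller_step()\n",
    "        done.append(agent.task_step())\n",
    "    return done\n",
    "\n",
    "\n",
    "print(run_loop(ToyAgent(), 3))\n",
    "```\n",
    "\n",
    "Each iteration follows the same order as the loop above: perceive and update memory, step the controller, then step a task."
   ]
  },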
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6h75P5HVgt4c"
   },
   "source": [
    "### Extend BaseAgent\n",
    "---\n",
    "\n",
    "The first you need to do is to extend the BaseAgent class and overwrite the following functions:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "xmLZEFKZn9V5"
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'BaseAgent' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-936f0390c187>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# grid_agent.py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mGridAgent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseAgent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworld\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworld\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mworld\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'BaseAgent' is not defined"
     ]
    }
   ],
   "source": [
    "# grid_agent.py\n",
    "class GridAgent(BaseAgent):\n",
    "    def __init__(self, world=None, opts=None):\n",
    "        self.world = world\n",
    "        self.pos = (0, 0, 0)\n",
    "        super(GridAgent, self).__init__(opts)\n",
    "\n",
    "    def init_memory(self):\n",
    "        pass\n",
    "\n",
    "    def init_perception(self):\n",
    "        pass\n",
    "\n",
    "    def init_controller(self):\n",
    "        pass\n",
    "    \n",
    "    def perceive(self):\n",
    "        pass\n",
    "\n",
    "    def get_incoming_chats(self):\n",
    "        pass\n",
    "\n",
    "    def controller_step(self):\n",
    "        pass\n",
    "    \n",
    "    def task_step(self, sleep_time=5):\n",
    "        pass\n",
    "    \n",
    "    def handle_exception(self, e):\n",
    "        pass\n",
    "    \n",
    "    def send_chat(self, chat):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CPv7-78T5866"
   },
   "source": [
    "We will go over each components in the following sections."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tZzHfwmooA_0"
   },
   "source": [
    "### Create a simple 5x5 grid world\n",
    "---\n",
    "\n",
    "Note that in the above ```___init___``` function we are passing a world to GridAgent, which is a simulated 5x5(x1) gridworld which hosts our agent. We also put a simple bot named \"target\" in it; our agent will need to catch it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0RcnBHlcoAGZ"
   },
   "outputs": [],
   "source": [
    "# world.py\n",
    "Bot = namedtuple(\"Bot\", \"entityId, name, pos, look\")\n",
    "\n",
    "class World:\n",
    "    def __init__(self, opts=None, spec=None):\n",
    "        target = Bot(1977, \"target\", Pos(3, 4, 0), Look(0, 0))\n",
    "        self.bots = [target]\n",
    "    \n",
    "    def get_bots(self, eid=None):\n",
    "        bots = self.bots if eid is None else [b for b in self.bots if b.entityId == eid]\n",
    "        return bots\n",
    "    \n",
    "    def remove_bot(self, eid):\n",
    "        self.bots[:] = [b for b in self.bots if b.entityId != eid]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wWVFDaDLpY8A"
   },
   "source": [
    "### Heuristic Perception\n",
    "---\n",
    "\n",
    "In order to catch the target, our agent needs to keep track of its location. We add a heuristic perception module that gets the position of all bots in the world and put them into memory.    \n",
    "\n",
    "In a more sophisticated agent, the perceptual models might be mediated by more in-depth heuristics or machine-learned models; but they would interface the Memory system in a similar way."
   ]
  },
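  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the perceive-then-store pattern concrete, here is a minimal sketch that does not depend on droidlet. It rests on assumptions made for illustration only: `ToyBot` and a plain dict are hypothetical stand-ins for the world's bot records and the agent's memory.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "# Hypothetical stand-ins: ToyBot and a plain dict replace the world's\n",
    "# bot records and the agent's memory, purely for illustration.\n",
    "ToyBot = namedtuple('ToyBot', 'entityId name pos')\n",
    "\n",
    "\n",
    "def perceive(bots, memory):\n",
    "    # create or refresh one memory entry per bot, keyed by entity id\n",
    "    for bot in bots:\n",
    "        memory[bot.entityId] = {'name': bot.name, 'pos': bot.pos}\n",
    "    return memory\n",
    "\n",
    "\n",
    "memory = {}\n",
    "perceive([ToyBot(1977, 'target', (3, 4, 0))], memory)\n",
    "perceive([ToyBot(1977, 'target', (2, 4, 0))], memory)  # the bot moved\n",
    "print(memory[1977]['pos'])\n",
    "```\n",
    "\n",
    "Calling `perceive` on every loop iteration keeps the stored position current, which is what lets later components look up the target without touching the world directly."
   ]
  },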
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kbIVQ_J2pgRF"
   },
   "outputs": [],
   "source": [
    "# heuristic_perception.py\n",
    "class HeuristicPerception:\n",
    "    def __init__(self, agent):\n",
    "        self.agent = agent\n",
    "\n",
    "    def perceive(self):\n",
    "        bots = self.agent.world.get_bots()\n",
    "        for bot in bots:\n",
    "            bot_node = self.agent.memory.get_player_by_eid(bot.entityId)\n",
    "            if bot_node is None:\n",
    "                memid = PlayerNode.create(self.agent.memory, bot)\n",
    "                bot_node = PlayerNode(self.agent.memory, memid)\n",
    "                self.agent.memory.tag(memid, \"bot\")\n",
    "            bot_node.update(self.agent.memory, bot, bot_node.memid)\n",
    "            \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SwB9CyUqpuwz"
   },
   "source": [
    "### Memory Module\n",
    "---\n",
    "\n",
    "To store and organize all the information, the agent needs a Memory Module. Here we just use [AgentMemory](https://facebookresearch.github.io/fairo/memory.html#base_agent.sql_memory.AgentMemory) of base_agent and use [PlayerNode](https://facebookresearch.github.io/fairo/memory.html#memorynodes) to represent the bot entity. You can also extend them and define your own Memory Nodes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GNP2yOQ7qSyh"
   },
   "source": [
    "### Tasks\n",
    "---\n",
    "\n",
    "A [Task](https://facebookresearch.github.io/fairo/tasks.html) is a world interaction whose implementation might vary from platform to platform. \n",
    "\n",
    "<!--It usually consists of a target (e.g. a certain position the agent wants to move to, or an entity the agent wants to destroy) and a stop condition. We usually break it down to several small steps and do one at a time in step function until it is finished (stop condition is met). -->\n",
    "\n",
    "#### **Simple Catch Task**\n",
    "\n",
    "We are going to create a simple Catch Task for our agent. We break it into two smaller subtasks: a Move Task and a Grab Task. \n",
    "\n",
    "In Move Task, the agent will simply head to a given position. The stop condition is when the agent is at the exact location of the target. It will move one block at a time to get close to the target until the stop condition is met.\n",
    "\n",
    "In Grab Task, the agent will simply grab the target physically. The stop condition is when the target has disappeared from the world."
   ]
  },
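  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-block-at-a-time movement and its stop condition can be sketched in isolation. This is an illustrative sketch, not droidlet code: `step_toward` is a hypothetical helper, and closing the gap along x before y is one possible rule chosen here for simplicity.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: close the gap along x first, then along y,\n",
    "# moving exactly one block per call.\n",
    "def step_toward(pos, target):\n",
    "    x, y, z = pos\n",
    "    tx, ty, _ = target\n",
    "    if x != tx:\n",
    "        x += 1 if x < tx else -1\n",
    "    elif y != ty:\n",
    "        y += 1 if y < ty else -1\n",
    "    return (x, y, z)\n",
    "\n",
    "\n",
    "pos, target = (0, 0, 0), (3, 4, 0)\n",
    "path = [pos]\n",
    "while pos != target:   # stop condition: agent is on the target's cell\n",
    "    pos = step_toward(pos, target)\n",
    "    path.append(pos)\n",
    "print(len(path) - 1, 'moves')\n",
    "```\n",
    "\n",
    "Because each call changes a single coordinate by one, the number of moves equals the Manhattan distance between the start and the target."
   ]
  },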
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nuN1YcW3qWfV"
   },
   "outputs": [],
   "source": [
    "# tasks.py\n",
    "class Move(Task):\n",
    "    def __init__(self, agent, task_data):\n",
    "        super(Move, self).__init__()\n",
    "        self.target = task_data[\"target\"]\n",
    "    \n",
    "    def step(self, agent):\n",
    "        super().step(agent)\n",
    "        if self.finished:\n",
    "            return\n",
    "        agent.move(self.target[0], self.target[1], self.target[2])\n",
    "        self.finished = True\n",
    "\n",
    "\n",
    "class Grab(Task):\n",
    "    def __init__(self, agent, task_data):\n",
    "        super(Grab, self).__init__()\n",
    "        self.target_eid = task_data[\"target_eid\"]\n",
    "\n",
    "    def step(self, agent):\n",
    "        super().step(agent)\n",
    "        if self.finished:\n",
    "            return\n",
    "\n",
    "        if len(agent.world.get_bots(eid=self.target_eid)) > 0:\n",
    "            agent.catch(self.target_eid)\n",
    "        else:\n",
    "            self.finished = True\n",
    "\n",
    "\n",
    "\n",
    "class Catch(Task):\n",
    "    def __init__(self, agent, task_data):\n",
    "        super(Catch, self).__init__()\n",
    "        self.target_memid = task_data[\"target_memid\"]\n",
    "    \n",
    "    def step(self, agent):\n",
    "        super().step(agent)\n",
    "        if self.finished:\n",
    "            return\n",
    "\n",
    "        # retrieve target info from memory:\n",
    "        target_mem = agent.memory.get_mem_by_id(self.target_memid)\n",
    "                    \n",
    "        # first get close to the target, one block at a time\n",
    "        tx, ty, tz = target_mem.get_pos()\n",
    "        x, y, z = agent.get_pos()\n",
    "        if np.linalg.norm(np.subtract((x, y, z), (tx, ty, tz))) > 0.:\n",
    "            if x != tx:\n",
    "                x += 1 if x - tx < 0 else -1\n",
    "            else:\n",
    "                y += 1 if y - ty < 0 else -1\n",
    "            move_task = Move(agent, {\"target\": (x, y, z)})\n",
    "            agent.memory.add_tick()\n",
    "            self.add_child_task(move_task, agent)\n",
    "            return\n",
    "\n",
    "        # once target is within reach, catch it!\n",
    "        grab_task = Grab(agent, {\"target_eid\": target_mem.eid})\n",
    "        agent.memory.add_tick()\n",
    "        self.add_child_task(grab_task, agent)\n",
    "        self.finished = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DcwcUXkwrARY"
   },
   "source": [
    "### Controller\n",
    "---\n",
    "\n",
    "The [Controller](https://facebookresearch.github.io/fairo/controller.html) decides which Tasks (if any) to put on the stack.  In the [craftassist](https://github.com/facebookresearch/fairo/blob/main/craftassist/agent/craftassist_agent.py) and [locobot](https://github.com/facebookresearch/fairo/blob/main/locobot/agent/locobot_agent.py) agents, the controller is itself a modular, multipart system.  \n",
    "\n",
    "In this tutorial, to keep things simple and self contained, the controller will just push the Catch task onto the stack.   \n",
    "\n",
    "For more in-depth discussion about Controllers we use, look [here](https://facebookresearch.github.io/fairo/controller.html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eBVi6KEhrYK6"
   },
   "outputs": [],
   "source": [
    "# grid_agent.py\n",
    "class GridAgent(BaseAgent):\n",
    "    ...\n",
    "    ...\n",
    "\n",
    "    def controller_step(self):\n",
    "        bot_memids = self.memory.get_memids_by_tag(\"bot\")\n",
    "        if self.memory.task_stack_peek() is None:\n",
    "            if bot_memids:            \n",
    "                task_data = {\"target_memid\": bot_memids[0]}\n",
    "                self.memory.task_stack_push(Catch(self, task_data))\n",
    "            else:\n",
    "                exit()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fCrFuEvmqw--"
   },
   "source": [
    "### Task Step\n",
    "---\n",
    "\n",
    "Here the agent steps the topmost Task on the Stack.\n"
   ]
  },
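  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pop-then-step pattern can be shown on its own with a toy stack. This is a sketch under stated assumptions: `ToyTask` and this standalone `task_step` are hypothetical, and a plain list replaces the task stack that the agent keeps in memory.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy task; a plain list stands in for the agent's task stack.\n",
    "class ToyTask:\n",
    "    def __init__(self, name, steps):\n",
    "        self.name, self.steps_left = name, steps\n",
    "\n",
    "    @property\n",
    "    def finished(self):\n",
    "        return self.steps_left == 0\n",
    "\n",
    "    def step(self):\n",
    "        self.steps_left -= 1\n",
    "\n",
    "\n",
    "def task_step(stack):\n",
    "    # clear finished tasks from the top of the stack\n",
    "    while stack and stack[-1].finished:\n",
    "        stack.pop()\n",
    "    if not stack:\n",
    "        return None          # nothing to do\n",
    "    stack[-1].step()         # step the topmost task\n",
    "    return stack[-1].name\n",
    "\n",
    "\n",
    "stack = [ToyTask('catch', 1), ToyTask('move', 2)]\n",
    "order = [task_step(stack) for _ in range(4)]\n",
    "print(order)\n",
    "```\n",
    "\n",
    "The task pushed last sits on top and is stepped first, so a child task runs to completion before its parent is stepped again."
   ]
  },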
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tyndePy0reAx"
   },
   "outputs": [],
   "source": [
    "# grid_agent.py\n",
    "class GridAgent(BaseAgent):\n",
    "    ...\n",
    "    ...\n",
    "    \n",
    "    def task_step(self, sleep_time=5):\n",
    "        # clear finsihed tasks from stack\n",
    "        while (\n",
    "            self.memory.task_stack_peek() and self.memory.task_stack_peek().task.check_finished()\n",
    "        ):\n",
    "            self.memory.task_stack_pop()\n",
    "\n",
    "        # do nothing if there's no task\n",
    "        if self.memory.task_stack_peek() is None:\n",
    "            return\n",
    "\n",
    "        # If something to do, step the topmost task\n",
    "        task_mem = self.memory.task_stack_peek()\n",
    "        if task_mem.memid != self.last_task_memid:\n",
    "            self.last_task_memid = task_mem.memid\n",
    "        task_mem.task.step(self)\n",
    "        self.memory.task_stack_update_task(task_mem.memid, task_mem.task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iLQj9knbr3xy"
   },
   "source": [
    "### Put it together\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tlYvhMSHr96Y"
   },
   "outputs": [],
   "source": [
    "# grid_agent.py\n",
    "\n",
    "class GridAgent(BaseAgent):\n",
    "    def __init__(self, world=None, opts=None):\n",
    "        self.world = world\n",
    "        self.last_task_memid = None\n",
    "        self.pos = (0, 0, 0)\n",
    "        super(GridAgent, self).__init__(opts)\n",
    "\n",
    "    def init_memory(self):\n",
    "        self.memory = AgentMemory()\n",
    "\n",
    "    def init_perception(self):\n",
    "        self.perception_modules = {}\n",
    "        self.perception_modules['heuristic'] = HeuristicPerception(self)\n",
    "\n",
    "    def init_controller(self):\n",
    "        pass\n",
    "    \n",
    "    def perceive(self):\n",
    "        self.world.step() # update world state\n",
    "        for perception_module in self.perception_modules.values():\n",
    "            perception_module.perceive()\n",
    "\n",
    "    def controller_step(self):\n",
    "        bot_memids = self.memory.get_memids_by_tag(\"bot\")\n",
    "        if self.memory.task_stack_peek() is None:\n",
    "            if bot_memids:            \n",
    "                task_data = {\"target_memid\": bot_memids[0]}\n",
    "                self.memory.task_stack_push(Catch(self, task_data))\n",
    "                logging.info(f\"pushed Catch Task of bot with memid: {bot_memids[0]}\")\n",
    "            else:\n",
    "                exit()\n",
    "    \n",
    "    def task_step(self, sleep_time=5):\n",
    "        while (\n",
    "            self.memory.task_stack_peek() and self.memory.task_stack_peek().task.check_finished()\n",
    "        ):\n",
    "            self.memory.task_stack_pop()\n",
    "\n",
    "        # do nothing if there's no task\n",
    "        if self.memory.task_stack_peek() is None:\n",
    "            return\n",
    "\n",
    "        # If something to do, step the topmost task\n",
    "        task_mem = self.memory.task_stack_peek()\n",
    "        if task_mem.memid != self.last_task_memid:\n",
    "            logging.info(\"Starting task {}\".format(task_mem.task))\n",
    "            self.last_task_memid = task_mem.memid\n",
    "        task_mem.task.step(self)\n",
    "        self.memory.task_stack_update_task(task_mem.memid, task_mem.task)\n",
    "        self.world.visualize(self)\n",
    "\n",
    "    \"\"\"physical interfaces\"\"\"\n",
    "    def get_pos(self):\n",
    "        return self.pos\n",
    "    \n",
    "    def move(self, x, y, z):\n",
    "        self.pos = (x, y, z)\n",
    "        return self.pos\n",
    "    \n",
    "    def catch(self, target_eid):\n",
    "        bots = self.world.get_bots(eid=target_eid)\n",
    "        if len(bots) > 0:\n",
    "            bot = bots[0]\n",
    "            if np.linalg.norm(np.subtract(self.pos, bot.pos)) <1.0001:\n",
    "                self.world.remove_bot(target_eid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Om3QYzusRrKU"
   },
   "source": [
    "### Run the agent\n",
    "\n",
    "To run the agent, you need to create a runtime populated with files we just created. Luckily we have already prepared one for you. Simply run the following command to pull it and install required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PaZJIVdaRVav"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/facebookresearch/fairo.git && cd examples/grid && pip install -r requirements.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TQ99VXaCsTsR"
   },
   "source": [
    "### Run it now!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xfoZXJHHAkit"
   },
   "outputs": [],
   "source": [
    "%run agent/grid_agent.py"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "how to build a simple agent.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
